// David Eberly, Geometric Tools, Redmond WA 98052
// Copyright (c) 1998-2019
// Distributed under the Boost Software License, Version 1.0.
// http://www.boost.org/LICENSE_1_0.txt
// http://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
// File Version: 3.0.1 (2018/10/05)

#pragma once

#include <LowLevel/GteLexicoArray2.h>
#include <cmath>
#include <vector>

namespace gte
{

template <typename Real>
class BandedMatrix
{
public:
    // Construction and destruction.
    ~BandedMatrix();
    BandedMatrix(int size, int numLBands, int numUBands);

    // Member access.
    inline int GetSize() const;
    inline std::vector<Real>& GetDBand();
    inline std::vector<Real> const& GetDBand() const;
    inline std::vector<std::vector<Real>>& GetLBands();
    inline std::vector<std::vector<Real>> const& GetLBands() const;
    inline std::vector<std::vector<Real>>& GetUBands();
    inline std::vector<std::vector<Real>> const& GetUBands() const;
    Real& operator()(int r, int c);
    Real const& operator()(int r, int c) const;

    // Factor the square banded matrix A into A = L*L^T, where L is a
    // lower-triangular matrix (L^T is an upper-triangular matrix).  This is
    // an LU decomposition that allows for stable inversion of A to solve
    // A*X = B.  The return value is 'true' iff the factorizing is successful
    // (L is invertible).  If successful, A contains the Cholesky
    // factorization: L in the lower-triangular part of A and L^T in the
    // upper-triangular part of A.
    bool CholeskyFactor();

    // Solve the linear system A*X = B, where A is an NxN banded matrix and B
    // is an Nx1 vector.  The unknown X is also Nx1.  The input to this
    // function is B.  The output X is computed and stored in B.  The return
    // value is 'true' iff the system has a solution.  The matrix A and the
    // vector B are both modified by this function.  If successful, A contains
    // the Cholesky factorization: L in the lower-triangular part of A and
    // L^T in the upper-triangular part of A.
    bool SolveSystem(Real* bVector);

    // Solve the linear system A*X = B, where A is an NxN banded matrix and
    // B is an NxM matrix.  The unknown X is also NxM.  The input to this
    // function is B.  The output X is computed and stored in B.  The return
    // value is 'true' iff the system has a solution.  The matrix A and the
    // vector B are both modified by this function.  If successful, A contains
    // the Cholesky factorization: L in the lower-triangular part of A and
    // L^T in the upper-triangular part of A.
    //
    // 'bMatrix' must have the storage order specified by the template
    // parameter.
    template <bool RowMajor>
    bool SolveSystem(Real* bMatrix, int numBColumns);

    // Compute the inverse of the banded matrix.  The return value is 'true'
    // when the matrix is invertible, in which case the 'inverse' output is
    // valid.  The return value is 'false' when the matrix is not invertible,
    // in which case 'inverse' is invalid and should not be used.  The
    // input matrix 'inverse' must be the same size as 'this'.
    //
    // 'bMatrix' must have the storage order specified by the template
    // parameter.
    template <bool RowMajor>
    bool ComputeInverse(Real* inverse) const;

private:
    // The linear system is L*U*X = B, where A = L*U and U = L^T,  Reduce this
    // to U*X = L^{-1}*B.  The return value is 'true' iff the operation is
    // successful.
    bool SolveLower(Real* dataVector) const;

    // The linear system is U*X = L^{-1}*B.  Reduce this to
    // X = U^{-1}*L^{-1}*B.  The return value is 'true' iff the operation is
    // successful.
    bool SolveUpper(Real* dataVector) const;

    // The linear system is L*U*X = B, where A = L*U and U = L^T,  Reduce this
    // to U*X = L^{-1}*B.  The return value is 'true' iff the operation is
    // successful.  See the comments for SolveSystem(Real*,int) about the
    // storage for dataMatrix.
    template <bool RowMajor>
    bool SolveLower(Real* dataMatrix, int numColumns) const;

    // The linear system is U*X = L^{-1}*B.  Reduce this to
    // X = U^{-1}*L^{-1}*B.  The return value is 'true' iff the operation is
    // successful.  See the comments for SolveSystem(Real*,int) about the
    // storage for dataMatrix.
    template <bool RowMajor>
    bool SolveUpper(Real* dataMatrix, int numColumns) const;

    int mSize;
    std::vector<Real> mDBand;
    std::vector<std::vector<Real>> mLBands, mUBands;

    // For return by operator()(int,int) for valid indices not in the bands,
    // in which case the matrix entries are zero,
    mutable Real mZero;
};


template <typename Real>
BandedMatrix<Real>::~BandedMatrix()
{
}

template <typename Real>
BandedMatrix<Real>::BandedMatrix(int size, int numLBands, int numUBands)
    :
    mSize(size),
    mZero((Real)0)
{
    if (size > 0
        && 0 <= numLBands && numLBands < size
        && 0 <= numUBands && numUBands < size)
    {
        mDBand.resize(size);
        std::fill(mDBand.begin(), mDBand.end(), (Real)0);

        if (numLBands > 0)
        {
            mLBands.resize(numLBands);
            int numElements = size - 1;
            for (auto& band : mLBands)
            {
                band.resize(numElements--);
                std::fill(band.begin(), band.end(), (Real)0);
            }
        }

        if (numUBands > 0)
        {
            mUBands.resize(numUBands);
            int numElements = size - 1;
            for (auto& band : mUBands)
            {
                band.resize(numElements--);
                std::fill(band.begin(), band.end(), (Real)0);
            }
        }
    }
    else
    {
        // Invalid argument to BandedMatrix constructor.
        mSize = 0;
    }
}

template <typename Real> inline
int BandedMatrix<Real>::GetSize() const
{
    return mSize;
}

template <typename Real> inline
std::vector<Real>& BandedMatrix<Real>::GetDBand()
{
    return mDBand;
}

template <typename Real> inline
std::vector<Real> const& BandedMatrix<Real>::GetDBand() const
{
    return mDBand;
}

template <typename Real> inline
std::vector<std::vector<Real>>& BandedMatrix<Real>::GetLBands()
{
    return mLBands;
}

template <typename Real> inline
std::vector<std::vector<Real>> const& BandedMatrix<Real>::GetLBands() const
{
    return mLBands;
}

template <typename Real> inline
std::vector<std::vector<Real>>& BandedMatrix<Real>::GetUBands()
{
    return mUBands;
}

template <typename Real> inline
std::vector<std::vector<Real>> const& BandedMatrix<Real>::GetUBands() const
{
    return mUBands;
}

template <typename Real>
Real& BandedMatrix<Real>::operator()(int r, int c)
{
    if (0 <= r && r < mSize && 0 <= c && c < mSize)
    {
        int band = c - r;
        if (band > 0)
        {
            int const numUBands = static_cast<int>(mUBands.size());
            if (--band < numUBands && r < mSize - 1 - band)
            {
                return mUBands[band][r];
            }
        }
        else if (band < 0)
        {
            band = -band;
            int const numLBands = static_cast<int>(mLBands.size());
            if (--band < numLBands && c < mSize - 1 - band)
            {
                return mLBands[band][c];
            }
        }
        else
        {
            return mDBand[r];
        }
    }
    // else invalid index


    // Set the value to zero in case someone unknowingly modified mZero on a
    // previous call to operator(int,int).
    mZero = (Real)0;
    return mZero;
}

template <typename Real>
Real const& BandedMatrix<Real>::operator()(int r, int c) const
{
    if (0 <= r && r < mSize && 0 <= c && c < mSize)
    {
        int band = c - r;
        if (band > 0)
        {
            int const numUBands = static_cast<int>(mUBands.size());
            if (--band < numUBands && r < mSize - 1 - band)
            {
                return mUBands[band][r];
            }
        }
        else if (band < 0)
        {
            band = -band;
            int const numLBands = static_cast<int>(mLBands.size());
            if (--band < numLBands && c < mSize - 1 - band)
            {
                return mLBands[band][c];
            }
        }
        else
        {
            return mDBand[r];
        }
    }
    // else invalid index


    // Set the value to zero in case someone unknowingly modified mZero on a
    // previous call to operator(int,int).
    mZero = (Real)0;
    return mZero;
}

template <typename Real>
bool BandedMatrix<Real>::CholeskyFactor()
{
    if (mDBand.size() == 0 || mLBands.size() != mUBands.size())
    {
        // Invalid number of bands.
        return false;
    }

    int const sizeM1 = mSize - 1;
    int const numBands = static_cast<int>(mLBands.size());

    int k, kMax;
    for (int i = 0; i < mSize; ++i)
    {
        int jMin = i - numBands;
        if (jMin < 0)
        {
            jMin = 0;
        }

        int j;
        for (j = jMin; j < i; ++j)
        {
            kMax = j + numBands;
            if (kMax > sizeM1)
            {
                kMax = sizeM1;
            }

            for (k = i; k <= kMax; ++k)
            {
                operator()(k, i) -= operator()(i, j)*operator()(k, j);
            }
        }

        kMax = j + numBands;
        if (kMax > sizeM1)
        {
            kMax = sizeM1;
        }

        for (k = 0; k < i; ++k)
        {
            operator()(k, i) = operator()(i, k);
        }

        Real diagonal = operator()(i, i);
        if (diagonal <= (Real)0)
        {
            return false;
        }
        Real invSqrt = ((Real)1) / std::sqrt(diagonal);
        for (k = i; k <= kMax; ++k)
        {
            operator()(k, i) *= invSqrt;
        }
    }

    return true;
}

template <typename Real>
bool BandedMatrix<Real>::SolveSystem(Real* bVector)
{
    return CholeskyFactor()
        && SolveLower(bVector)
        && SolveUpper(bVector);
}

template <typename Real>
template <bool RowMajor>
bool BandedMatrix<Real>::SolveSystem(Real* bMatrix, int numBColumns)
{
    return CholeskyFactor()
        && SolveLower<RowMajor>(bMatrix, numBColumns)
        && SolveUpper<RowMajor>(bMatrix, numBColumns);
}

template <typename Real>
template <bool RowMajor>
bool BandedMatrix<Real>::ComputeInverse(Real* inverse) const
{
    LexicoArray2<RowMajor, Real> invA(mSize, mSize, inverse);

    BandedMatrix<Real> tmpA = *this;
    for (int row = 0; row < mSize; ++row)
    {
        for (int col = 0; col < mSize; ++col)
        {
            if (row != col)
            {
                invA(row, col) = (Real)0;
            }
            else
            {
                invA(row, row) = (Real)1;
            }
        }
    }

    // Forward elimination.
    for (int row = 0; row < mSize; ++row)
    {
        // The pivot must be nonzero in order to proceed.
        Real diag = tmpA(row, row);
        if (diag == (Real)0)
        {
            return false;
        }

        Real invDiag = ((Real)1) / diag;
        tmpA(row, row) = (Real)1;

        // Multiply the row to be consistent with diagonal term of 1.
        int colMin = row + 1;
        int colMax = colMin + static_cast<int>(mUBands.size());
        if (colMax > mSize)
        {
            colMax = mSize;
        }

        int c;
        for (c = colMin; c < colMax; ++c)
        {
            tmpA(row, c) *= invDiag;
        }
        for (c = 0; c <= row; ++c)
        {
            invA(row, c) *= invDiag;
        }

        // Reduce the remaining rows.
        int rowMin = row + 1;
        int rowMax = rowMin + static_cast<int>(mLBands.size());
        if (rowMax > mSize)
        {
            rowMax = mSize;
        }

        for (int r = rowMin; r < rowMax; ++r)
        {
            Real mult = tmpA(r, row);
            tmpA(r, row) = (Real)0;
            for (c = colMin; c < colMax; ++c)
            {
                tmpA(r, c) -= mult*tmpA(row, c);
            }
            for (c = 0; c <= row; ++c)
            {
                invA(r, c) -= mult*invA(row, c);
            }
        }
    }

    // Backward elimination.
    for (int row = mSize - 1; row >= 1; --row)
    {
        int rowMax = row - 1;
        int rowMin = row - static_cast<int>(mUBands.size());
        if (rowMin < 0)
        {
            rowMin = 0;
        }

        for (int r = rowMax; r >= rowMin; --r)
        {
            Real mult = tmpA(r, row);
            tmpA(r, row) = (Real)0;
            for (int c = 0; c < mSize; ++c)
            {
                invA(r, c) -= mult*invA(row, c);
            }
        }
    }

    return true;
}

template <typename Real>
bool BandedMatrix<Real>::SolveLower(Real* dataVector) const
{
    int const size = static_cast<int>(mDBand.size());
    for (int r = 0; r < size; ++r)
    {
        Real lowerRR = operator()(r, r);
        if (lowerRR >(Real)0)
        {
            for (int c = 0; c < r; ++c)
            {
                Real lowerRC = operator()(r, c);
                dataVector[r] -= lowerRC * dataVector[c];
            }
            dataVector[r] /= lowerRR;
        }
        else
        {
            return false;
        }
    }
    return true;
}

template <typename Real>
bool BandedMatrix<Real>::SolveUpper(Real* dataVector) const
{
    int const size = static_cast<int>(mDBand.size());
    for (int r = size - 1; r >= 0; --r)
    {
        Real upperRR = operator()(r, r);
        if (upperRR > (Real)0)
        {
            for (int c = r + 1; c < size; ++c)
            {
                Real upperRC = operator()(r, c);
                dataVector[r] -= upperRC * dataVector[c];
            }

            dataVector[r] /= upperRR;
        }
        else
        {
            return false;
        }
    }
    return true;
}

template <typename Real>
template <bool RowMajor>
bool BandedMatrix<Real>::SolveLower(Real* dataMatrix, int numColumns) const
{
    LexicoArray2<RowMajor, Real> data(mSize, numColumns, dataMatrix);

    for (int r = 0; r < mSize; ++r)
    {
        Real lowerRR = operator()(r, r);
        if (lowerRR >(Real)0)
        {
            for (int c = 0; c < r; ++c)
            {
                Real lowerRC = operator()(r, c);
                for (int bCol = 0; bCol < numColumns; ++bCol)
                {
                    data(r, bCol) -= lowerRC * data(c, bCol);
                }
            }

            Real inverse = ((Real)1) / lowerRR;
            for (int bCol = 0; bCol < numColumns; ++bCol)
            {
                data(r, bCol) *= inverse;
            }
        }
        else
        {
            return false;
        }
    }
    return true;
}

template <typename Real>
template <bool RowMajor>
bool BandedMatrix<Real>::SolveUpper(Real* dataMatrix, int numColumns) const
{
    LexicoArray2<RowMajor, Real> data(mSize, numColumns, dataMatrix);

    for (int r = mSize - 1; r >= 0; --r)
    {
        Real upperRR = operator()(r, r);
        if (upperRR > (Real)0)
        {
            for (int c = r + 1; c < mSize; ++c)
            {
                Real upperRC = operator()(r, c);
                for (int bCol = 0; bCol < numColumns; ++bCol)
                {
                    data(r, bCol) -= upperRC * data(c, bCol);
                }
            }

            Real inverse = ((Real)1) / upperRR;
            for (int bCol = 0; bCol < numColumns; ++bCol)
            {
                data(r, bCol) *= inverse;
            }
        }
        else
        {
            return false;
        }
    }
    return true;
}


}
